{
 "cells": [
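  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# mxnet image recognition tutorial\n",
    "\n",
    "\n",
    "Using the data from the TinyMind Chinese calligraphy recognition competition (《汉字书法识别》) as an example, this tutorial walks through the whole process of training an image classification model with mxnet.\n",
    "\n",
    "For the data, see:\n",
    "https://www.tinymind.cn/competitions/41#property_23\n",
    "\n",
    "Alternatively, download it here.\n",
    "Free practice round data:\n",
    "Training set: https://pan.baidu.com/s/1UxvN7nVpa0cuY1A-0B8gjg password: aujd\n",
    "\n",
    "Test set: https://pan.baidu.com/s/1tzMYlrNY4XeMadipLCPzTw password: 4y9k"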
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Data exploration\n",
    "Please refer to the official data description."
   ]
  },
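  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Data processing\n",
    "\n",
    "In the competition, only the train set comes with accurate labels, so this tutorial uses the train data alone. In practice, the data behind both the stage 1 and stage 2 leaderboards has to be used as well."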
   ]
  },
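  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Downloading the data\n",
    "\n",
    "After downloading, extract the data to get a train folder containing 100 subfolders; the name of each subfolder is the label of one Chinese character. This directory layout is common in classification tasks. The following command, run inside the train folder, counts the files in each subfolder so you can check that the data set matches the competition's description:\n",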
    "```sh\n",
    "for l in $(ls); do echo $l $(ls $l|wc -l); done\n",
    "```"
   ]
  },
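  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The same check can be sketched in Python. This is a minimal sketch, assuming the extracted `train` folder sits in the current directory; the helper name `count_files_per_class` is hypothetical and not part of the original tutorial.\n",
    "\n",
    "```python\n",
    "import os\n",
    "\n",
    "\n",
    "def count_files_per_class(root):\n",
    "    # Map each class folder under root to the number of entries it contains.\n",
    "    counts = {}\n",
    "    for name in sorted(os.listdir(root)):\n",
    "        folder = os.path.join(root, name)\n",
    "        if os.path.isdir(folder):\n",
    "            counts[name] = len(os.listdir(folder))\n",
    "    return counts\n",
    "\n",
    "\n",
    "# Only print when the assumed folder is actually present.\n",
    "if os.path.isdir('train'):\n",
    "    for label, n in count_files_per_class('train').items():\n",
    "        print(label, n)\n",
    "```\n",
    "\n",
    "A folder whose count differs from the others is worth inspecting before training."
   ]
  },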
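  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Splitting the data set\n",
    "\n",
    "Because only the train set is used here, it has to be split so that part of it can serve for validation during training, that is, to build a validation set.\n",
    "> In general, train is used to fit the model, validation is used to check the model and tune hyperparameters, and test is used for the final evaluation; when we speak of a model's performance we usually mean its metrics on the test set. In real projects, however, there is often only a train set and no reliable test set, so part of the train set is usually split off as validation and the performance on validation is reported as the final performance.\n",
    "\n",
    "> In most cases we do not strictly distinguish between validation and test.\n",
    "\n",
    "Here, 50 randomly chosen files from each folder are moved out to form the validation set.\n",
    "\n",
    "```sh\n",
    "export train=train\n",
    "export val=validation\n",
    "\n",
    "for d in $(ls $train); do\n",
    "    mkdir -p $val/$d/\n",
    "    for f in $(ls $train/$d | shuf | head -n 50 ); do\n",
    "        mv $train/$d/$f $val/$d/;\n",
    "    done;\n",
    "done\n",
    "```\n",
    "\n",
    "> Note that the validation set takes part in training only indirectly, through hyperparameter tuning, so some data is effectively wasted."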
   ]
  },
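  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For readers without `shuf`, the split can also be sketched in Python. This is only a sketch under assumptions: the function name `split_validation`, its parameters and the fixed seed are hypothetical choices, not part of the original tutorial.\n",
    "\n",
    "```python\n",
    "import os\n",
    "import random\n",
    "import shutil\n",
    "\n",
    "\n",
    "def split_validation(train_dir, val_dir, per_class=50, seed=0):\n",
    "    # Move per_class randomly chosen files of every class folder to val_dir.\n",
    "    rng = random.Random(seed)\n",
    "    for label in sorted(os.listdir(train_dir)):\n",
    "        src = os.path.join(train_dir, label)\n",
    "        if not os.path.isdir(src):\n",
    "            continue\n",
    "        dst = os.path.join(val_dir, label)\n",
    "        os.makedirs(dst, exist_ok=True)\n",
    "        files = sorted(os.listdir(src))\n",
    "        # Never ask for more files than the folder holds.\n",
    "        for name in rng.sample(files, min(per_class, len(files))):\n",
    "            shutil.move(os.path.join(src, name), os.path.join(dst, name))\n",
    "```\n",
    "\n",
    "Calling `split_validation('train', 'validation')` once performs the split; calling it again would move another 50 files per class, so it should run only once."
   ]
  },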
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Model training code - data\n",
    "First, import mxnet and check its version."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'1.6.0'"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import mxnet as mx\n",
    "\n",
    "mx.__version__"
   ]
  },
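  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Inside the model everything is a number, which is not readable, so these numbers have to be given a meaning by hand. Here the 100 Chinese characters serve as the textual description of the model's numeric outputs.\n",
    "\n",
    "Note that model training is usually repeated many times, so a stable mapping between labels and indices is essential. The Python code below generates a label file on the first run; on later runs, if the file is found, it is simply loaded."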
   ]
  },
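  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "if os.path.exists(\"labels.txt\"):\n",
    "    # Reuse the stored label order so that indices stay stable across runs.\n",
    "    with open(\"labels.txt\", encoding=\"utf-8\") as inf:\n",
    "        classes = [l.strip() for l in inf]\n",
    "else:\n",
    "    # First run: take the folder names as labels and store their order.\n",
    "    classes = os.listdir(\"worddata/train/\")\n",
    "    with open(\"labels.txt\", \"w\", encoding=\"utf-8\") as of:\n",
    "        of.write(\"\\n\".join(classes))\n",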
    "\n",
    "class_idx = {v: k for k, v in enumerate(classes)}\n",
    "idx_class = dict(enumerate(classes))"
   ]
  },
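  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Like pyTorch's `ImageFolder`, mxnet's `ImageFolderDataset` organizes classes in its own way: it numbers them by the sorted order of the folder names. Since we want our own order from labels.txt, a conversion is needed."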
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from PIL import Image\n",
    "\n",
    "pth_classes = classes[:]\n",
    "pth_classes.sort()\n",
    "pth_classes_to_idx = {v: k for k, v in enumerate(pth_classes)}\n",
    "\n",
    "\n",
    "def transform(data, pth_idx):\n",
    "    return data, class_idx[pth_classes[pth_idx]]"
   ]
  },
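  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see what the `transform` function above does, here is a small self-contained sketch with three made-up class names (`c`, `a`, `b` are placeholders, not labels from the data set). It assumes, as the code above does, that the data set numbers the folders in sorted order; the helper name `remap` is hypothetical.\n",
    "\n",
    "```python\n",
    "# Hypothetical label order, as it might be stored in labels.txt.\n",
    "classes = ['c', 'a', 'b']\n",
    "class_idx = {name: i for i, name in enumerate(classes)}\n",
    "\n",
    "# The order in which the folders would be numbered when sorted by name.\n",
    "sorted_classes = sorted(classes)\n",
    "\n",
    "\n",
    "def remap(sorted_idx):\n",
    "    # Translate an index in sorted order into an index in labels.txt order.\n",
    "    return class_idx[sorted_classes[sorted_idx]]\n",
    "\n",
    "\n",
    "print([remap(i) for i in range(len(classes))])  # prints [1, 2, 0]\n",
    "```\n",
    "\n",
    "With this remapping, the index the model learns for a character no longer depends on how the folder names happen to sort."
   ]
  },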
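  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "mxnet provides an API for reading training data directly from a directory; it is used below.\n",
    "\n",
    "Two data sets are created, one for train and one for validation.\n",
    "\n",
    "Note that the pixel values of the images lie in the range [0, 255]. In addition, mxnet loads images with OpenCV, which yields the (H, W, C) layout, whereas the data used for training must be (C, H, W), so the data has to be converted. `ToTensor` handles both: it rescales the values to [0, 1] and changes the layout. The train set also gets some preprocessing for data augmentation (brightness jitter; the random rotation is commented out in the code below) and is shuffled, while validation needs neither.\n",
    "\n",
    "> ToTensor changes the data layout, so it has to come last."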
   ]
  },
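  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The layout and value-range conversion can be illustrated with plain NumPy. This is a sketch of the idea only, using a random array as a stand-in for a real image; it is not how mxnet implements `ToTensor` internally.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "# A fake 128x128 grayscale image in (H, W, C) layout with uint8 pixels.\n",
    "hwc = np.random.randint(0, 256, size=(128, 128, 1), dtype=np.uint8)\n",
    "\n",
    "# Move the channel axis to the front and rescale to floats in [0, 1].\n",
    "chw = hwc.transpose(2, 0, 1).astype(np.float32) / 255.0\n",
    "\n",
    "print(chw.shape, chw.dtype)  # prints (1, 128, 128) float32\n",
    "```\n",
    "\n",
    "Transforms that expect the (H, W, C) layout, such as `Resize`, therefore have to run before this conversion."
   ]
  },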
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "from multiprocessing import cpu_count\n",
    "\n",
    "transform_train = mx.gluon.data.vision.transforms.Compose(\n",
    "    [\n",
    "        # mx.gluon.data.vision.transforms.RandomRotation((-15, 15), zoom_out=True),\n",
    "        # The mxnet version with random rotation has not been released yet.\n",
    "        mx.gluon.data.vision.transforms.Resize((128, 128)),\n",
    "        mx.gluon.data.vision.transforms.RandomColorJitter(brightness=0.5),\n",
    "        mx.gluon.data.vision.transforms.ToTensor(),\n",
    "    ]\n",
    ")\n",
    "transform_val = mx.gluon.data.vision.transforms.Compose(\n",
    "    [\n",
    "        mx.gluon.data.vision.transforms.Resize((128, 128)),\n",
    "        mx.gluon.data.vision.transforms.ToTensor(),\n",
    "    ]\n",
    ")\n",
    "\n",
    "\n",
    "img_gen_train = mx.gluon.data.vision.datasets.ImageFolderDataset(\n",
    "    \"worddata/train/\", transform=transform, flag=0\n",
    ")\n",
    "\n",
    "\n",
    "img_gen_val = mx.gluon.data.vision.datasets.ImageFolderDataset(\n",
    "    \"worddata/validation/\", transform=transform, flag=0\n",
    ")\n",
    "\n",
    "batch_size = 32\n",
    "\n",
    "img_train = mx.gluon.data.DataLoader(\n",
    "    img_gen_train.transform_first(transform_train),\n",
    "    batch_size=batch_size,\n",
    "    shuffle=True,\n",
    "    num_workers=cpu_count(),\n",
    ")\n",
    "img_val = mx.gluon.data.DataLoader(\n",
    "    img_gen_val.transform_first(transform_val),\n",
    "    batch_size=batch_size,\n",
    "    num_workers=cpu_count(),\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "At this point both data sets are ready to use. Before the actual training, let's look at how the data is read and what it looks like."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((32, 1, 128, 128), (32,))"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "for imgs, labels in img_train:\n",
    "    # img_train is an iterable rather than an iterator, so next() cannot be called on it directly\n",
    "    break\n",
    "imgs.shape, labels.shape"
   ]
  },
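  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The difference between an iterable and an iterator can be shown without mxnet. In this sketch the class `Batches` is a hypothetical stand-in for a data loader; it yields placeholder strings instead of real batches.\n",
    "\n",
    "```python\n",
    "class Batches:\n",
    "    # An iterable: it defines __iter__ but not __next__.\n",
    "    def __iter__(self):\n",
    "        return iter([('imgs0', 'labels0'), ('imgs1', 'labels1')])\n",
    "\n",
    "\n",
    "loader = Batches()\n",
    "# next(loader) would raise TypeError; wrapping it in iter() first works.\n",
    "first = next(iter(loader))\n",
    "print(first)  # prints ('imgs0', 'labels0')\n",
    "```\n",
    "\n",
    "The `for ... break` loop above is simply another way to take the first item of such an iterable."
   ]
  },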
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The data is laid out as (batch, channel, height, width). The images are grayscale, so channel is 1.\n",
    "\n",
    "> Note that pyTorch and mxnet use a different data layout from Tensorflow, so the data is handled somewhat differently.\n",
    "\n",
    "Plot one image to check that the data and the label match.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'利'"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAQEAAAD7CAYAAABqkiE2AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8li6FKAAAgAElEQVR4nO2de5QcVdXof7sf80xCXhACgY8oXK5cloYY5BV5iCBETJCnEr0IRHzBFb0LJfFbAmvxAd7rQtRPIRFQiIGEDwLEXDFgAEGBmAkgDzEQlEAgQEIS8phMJj197h/du7qmu2emp7urq6dr/9aaNT3V3XX2nKo6Z5+999lbnHMYhhFdYmELYBhGuNggYBgRxwYBw4g4NggYRsSxQcAwIo4NAoYRcQIbBETkFBFZLSJrROSKoNoxDKMyJIg4ARGJA68AJwHrgJXAF51zf696Y4ZhVEQioPN+AljjnPsngIgsBGYARQeBsWPHugMOOCAgUYx64JlnnvFeT548OURJosuqVas2Ouf2zD8e1CCwL/Cm7+91wBH+D4jIxcDFAPvvvz8dHR0BiWLUA01NTQCIiF3rkBCRtcWOh2YYdM7Nc85Ncc5N2XPPgsHJaDBEBBEhHo+zfPlyli9fHrZIRpagNIG3gP18f0/IHjMixrhx4wBIp9MAOOc45ZRTANi9e3dochk5gtIEVgIHichEEWkCvgAsCagtwzAqIBBNwDmXEpFLgGVAHLjNOfdSEG0Z9c2mTZuAnCYAMHLkyLDEMYoQ1HIA59zvgd8HdX7DMKpDYIOAYUBvDUDp7OwMQRKjLyxs2DAijmkCRqDE43EAenp6AIjFYp52cO+99wJw5plnhiOcAdggYATIrFmzCo6l02laWloAe/jrBVsOGEbEMU3ACIzbb7+9wDAYi8VIpVIhSWQUwzQBw4g4pgkYgZFOp4nFYt5r/a1GQqM+sEFgkFx77bVs3boVgEQi033XXHNNmCIZRkXYcsAwIk5kNYGf/exnAHR3dxeorHPmzKG5uRmAHTt2ADl/t4h46qweu+GGG7wdcdu3bwfwvm/0RvvaqB/sihhGxKk7TWDevHlAZsb1UywXYltbG1u2bAHgkksu8Y7/+te/BnL71V944QUAbr755qKx7MXQ+HaVo9j3VCb/vvgRI0YA0NXVVfA/RI329nZ27tzZ61g6nfY0KKM+qKtBoKmpyfMhlzIIJBIJ74a6/PLLve/pQ6mGOz1XU1MTXV1dwQifRdNoNTU1RT5pxvTp07n77rt7Hav35cD8+fOBXLSjP6ahUb0a9X1FDMMInLrQBLZu3cqyZcvYvXs3yWQSKJz5i6njPT093uisnxeRgu/qZ6qhnjvnvPPorBaLxbw2VM5YLMadd94JwHnnnVdxu0ORRYsWFRxLp9OehlYvPPjgg0BGc8m/V/xG46997WsAzJ07NwQpg8M0AcOIOIEUHxks8XjctbS0sHPnzoIRWCnVoBcU6vLbtWsXw4YNA3LuQOcc7e3t3vvQW0uJKs3NzUX3Cei1Ddtmolqnaibd3d3ee34tDzL3XxjytrW1ATlt5bjjjiv7XCKyyjk3Jf+4aQKGEXHqYnF22GGH0dHRQTwe96z99aCh+NEZPhaLebPDBRdcAMCHPvQh5syZE5ps9ca0adOAjGW9v6QiYaAuyxEjRvRrIypm96k1qqkAXpr26dOnF7W1VEJdDALK1VdfzXXXXQfkXDM6GDQ1NfUyuuln1CVXTO3UG9B/rlJvQP1uvk873+9dTzQ1NfVSacPC/+AUM/CGMcBPnToVgJUrV3pyKP4BSh/+ernuKqde1yVLqp+535YDhhFx6koT+Pd//3f22GMPILczb9u2bUBmNtcZREdHf2DQEUdkSh0effTR3rGbbrqp1/kHo4rqOb70pS8BcNttt5X3T9WQmTNneipk2Ea3vvC7U4NG+6CtrY3W1taSvpPvcn700UeD
Ea5MgliWmCZgGBGnLlyEU6ZMcbWoVOtf8w3EueeeC+AF/AwV1KWks2AYGsFpp50GwLJly7xjfg0sX6MLCp39u7u7CwKUitmQ2tvbPQPwihUrAJg0aVKgMvZHMpks6KOWlhY+97nPAbBw4cJBna8vF2HZywER2Q+4AxgHOGCec+6nIjIaWAQcALwOnOOc21xuO9UknU4zfPhwAG8PQTHVdMaMGUPu4VfGjx8PwBtvvAFQsE26Hghqm7U+2KNGjQLw4jk2bdpU8NDHYrECo/KuXbtYtWoVAIceemggMlaKc44HHnigquesZDmQAv63c+4Q4EjgWyJyCHAFsNw5dxCwPPu3YRh1StmagHNuPbA++3qbiLwM7AvMAI7Pfux24DHg+xVJWUVUvdMRP5FIeGqjRgDed9994QhXBV577TUgt5vRv4OyXtyHqp0sXrwYgDPOOKMq59Yy6LoNXK9nX8Y0XZaoO/Dpp5+uWw1A6enpqbpWVxXDoIgcABwGrADGZQcIgHfILBeKfediEekQkY4NGzZUQwzDMMqgYhehiAwD7gUuc85t9RvenHNORIpaHp1z84B5kDEMVipHqTzxxBNAboZsaWnxZoxjjz22VmIEzi9+8QsAb+cbwI033gjAZZddFopMkLFNBFF3YOLEiZ6dR20O/rwSqgX5tQLVBDTpzIEHHlh1uaqNPwrzG9/4BlDoCh8sFQ0CIpIkMwAscM4tzh5+V0TGO+fWi8h44L2KJAwIHQR27drlWY7/+Mc/hilSVfnqV78K4IUzd3Z28t3vfhfIRb/Nnj07kLb9iWHyvU/+h7AaywC/B0CzOmk2aMWfSEZLoKVSKVavXg3A/vvvX7EctSKIrExlLwck8xTdCrzsnLvB99YS4Pzs6/OB6poyDcOoKmXHCYjIVOAJ4AVALRVzyNgF7gb2B9aScRFu6u9ctYoT8OOfkYptJW00YrFYr1kTgoshOOGEEwB48sknexUdUTSWQaNBy0G3but11KzQUJg8ZtiwYQVZoF9//XX22muvstuvBcXiBCDn+vzggw8Gdb6qxwk45/4M9BV5c2K55zUMo7bU1d6BsIhCVuB0Ou3NwEFv63388ce9Nou55ypN9trW1uadQ6+d/xqqe1Sj/7Zt2+bZAt5//32AkvcShEmx+zIWi3ku0O985zsA/OQnP6moHds7YBgRxzQBchbXvqzLjcL9998PwOmnnw5kZuTRo0cDcOGFFwLw4x//uOJ2VOPo7OwsmlREbTD33nsvAGeeeWZJ51U7wEAuRrV56Gwfi8Uqsj+ERTKZLBrWrv03WJtAX0RqA5Efv5qqN8uYMWOAXNx9o6L7J3p6erz/XYu4NDc3e+pmuejD2tnZWRCf71+C6M3c3Nzc58Db3t5ekGAmmUx6bs7+lnL1ksuwXEaMGNHL4Am5pQ7kMg+VOmlZjkHDMIoS2eWAziqxWMybVXQ2bHRUNR49ejSbN2c2eKp2EIvF2GeffYCcEU0NbAOhW1zVaBeLxQpUd78hUt9LpVJe+7ok01ByvzqsM3t3d7e3zJgwYQIAa9eu9bSCRsnyPHfuXGbOnAnkZv1UKuX1YbUiL00TMIyIE1mbgI6sY8eO9Qwsla6FhyLaDzq7xOPxghRbusYfyLimgTjlzFD+0u+DOUexZCWNRL4xNJ1OFyTcLdXmUfVgoaGOGsQ2btxYVwk3ak3+DeSPTdc+0uVAc3Nz0aWBZmGqhFIe4GIxB42i+veFZt/WTV/F9g5cddVVXHXVVWW3YcsBw4g4kdUE1K3S1tYWyNbWoUpPT4+n1udH9sViMW8m0s845zztwL9Db7Dka2M66yeTSU9L0PMefPDBvPjii4NuYyhy+OGHA7llQbFoy1deeaWiNkwTMIyIE1lNQGlEY1Kl9GUjaW1t9XbjjRw5Esi48tR2UM2+LFYhaKgG/VSC5oVQo3Uxu0iltQgiPwikUqlIbCAaDPqwqc9eo9b8RsF33nnHe61xFvlpvQeDLi/0HNrm
1KlTeeihh8o+71BHE5/0R6UFUmw5YBgRJ/KagGkBfaPFQ048MZMeoqenx4srUCOgLg/0/XJRDUCNf43u+isVf7ZoKG503bhxY0VtmCZgGBEn8pqA0TdHHXUUkDNKjRo1youurLZbVV1fjZzirRzyt2IXo9JgN9MEDCPimCZALn7e6J/Nmzd7QTqTJ08GMjOUugZLta/o59UjkE6nS96pGFX6CuCCyrwyYIMAiUSiV6IGo3+0TJeq7U899RTHHHMMUGjc64v8QdcGgL7RZUB/E1Wlxm1bDhhGxImsJnDllVcCmVnIogbL56ijjvIMU6VqVNrfGmRk9I0aBvuLlty1a5cXMKQ1HwaDaQKGEXEiqwnccEOmclpLS4vNSFXi4x//OAB//etfvWMa1+5PhGGaV+nceuutAFx00UVAb3eg/7Xez+VoAtWoShwHOoC3nHOnichEYCEwBlgFfNk5V3fOX82SY56B2pJOp3nkkUfCFmPIcPbZZwO5jURBUI3lwLeBl31//wj4iXPuQGAzcFEV2jAMIyAqGgREZALwWeCW7N8CfAq4J/uR24HTK2kjaCxGvXokEokBfdZadyDKKd0GQzKZ9Oos9Ldsdc6VvcyqVBO4EfgeuarEY4Atzjl1FK8D9i32RRG5WEQ6RKRD00sbhlF7yh4EROQ04D3n3Kpyvu+cm+ecm+Kcm7LnnnuWK4ZRR6RSqQEDhdLpdEWzVlSJx+NFk4xWg0oMg8cA00VkGtACjAB+CowUkURWG5gAvFW5mIZhBEXZmoBzbrZzboJz7gDgC8AjzrmZwKPAWdmPnQ88ULGUxpDgqaee4qmnnhrwcyJieRyqTCV9GkSw0PeB74rIGjI2glsDaKNqxOPxkgxaRnWoNB9eVDniiCM44ogjir7X1NRET09P2Ubuqtz5zrnHgMeyr/8JfKIa5zUMI3hs+jNqirkGy+Paa68FikcExmIx2trayj636WaGEXEirwkkk0mrQFRDamkT0Oo9K1eurFmbQXHsscf2+/4999zT7/v9EflBoKury4yCVaZYFhxVVzs7O2s26P7jH//w5GmUxCUtLS0FeRhTqRSnn54JzL3//vsHfU5bDhhGxIn8FOgveGlUB/VXx+NxL9GILgMSiQQnnXRSoO0fdNBBQK4mwrBhwwJtr5b4tSg1sqbTaZYvX172OU0TMIyIE/lBwFxW1aerq4uuri5SqZS3Y7Czs5POzk6vxHaQaJsaRbd9+3aam5s9W8VQ5sknn6SpqYmmpiZaWlpoaWkhnU6zffv2XtWgBkPklwM9PT0NaRjUzSb6v913331MmzatJm3r8ioWi3nLAFVjH3744cDb112p+r8757ziqkOdzs5OL9+gP0Jw9OjRZZ8z8pqAYUSdxpsCI86ll14K5Nx0OgOfe+65Xkq1WqKJMNRAWIuNQzpDqkbS09PTMLUl2tvbe2k4emzz5s0AfPKTnwTgiSeeKPmcpgkYRsQxTYDqF9cMEzW86QysiVSLla8KCr8BTmflWvbx3nvvDcDrr78OZGbM9evX16z9IPHvJFSj9o4dOzwNq5x+Nk3AMCKOaQLkZi4NMnn11VfDFKciVAPInxmCSk1VDN3p9uCDDzJ27FgAr6R5LVyEr732GpALVU4kEqHYQ4Kgra3NcwVqX+7YsaOic9ogQK4gpt6oQxl92HVgCyNmXg1WiUTCG4TUrXXwwQfXTA7/EqhR8lguXbqUz372s0CudoZ/U5a/8Eup2HLAMCJO5DUB/76BRkh9ruWofv7znwO5ZUEtIyP9Sw9VXYcPH16z9vNJJBK8//77obVfTUaMGOFpOOr2rNToapqAYUScyGsCjUq97InQwJYwKj2NHDkSaAxbj7Jjxw6vL/MDwsrFBoEGJb8acFjoDRpGlmG1mmvps0bg6KOPpqWlBahe7IUtBwwj4kReE4jFYp5xsBHKlH/7298Gcup3GOr40qVLgd4GwjBmYjWcdXZ21rztINH0YtXSskwTMIyIE3lN
oLm5uWDX2VBOTFmva98wUrjp2rnRNAENvFLNtdJrXtEgICIjgVuAQwEHXAisBhYBBwCvA+c45zZXJGWAdHZ2egkZ1JA0VAcAyN34Gh+gA1xY5b90SRBG7UGNUYjFYqF4J8KgnOtc6Z3xU+APzrn/DnwMeBm4AljunDsIWJ792zCMOqVsTUBE9gCOBb4C4JzrBrpFZAZwfPZjt5OpUfj9SoQMGvUjN1LWYTWKha3V6AxcSZmsStuGnIbUCLzxxhtA9ZZ+lWgCE4ENwK9F5FkRuUVE2oFxzjndvP0OMK7Yl0XkYhHpEJGORgjXNYyhSiWDQAKYDNzknDsM2EGe6u8yU2vR6dU5N885N8U5NyXsHV7OuYbRAnbv3s3u3bu9UtXxeLym24jzCbPsu15XEQm9H6qJZlPW/8lvB9AMy5MmTWLSpEklna+SQWAdsM45tyL79z1kBoV3RWR8VqDxwHsVtGEYRsCUPUQ7594RkTdF5GDn3GrgRODv2Z/zgeuzvx+oiqQB0ihaAMDcuXOBnPsov25dLXjyySeBjBbgT4FVazSYZtSoUWzatKnm7QfF9ddfD+Q8Lv77V18PxhtTqZ52KbBARJqAfwIXkNEu7haRi4C1wDkVtmEMglmzZgEwb948oPiNEjTXXnut12Y9xC1cc801YYtQVebPnw/0HzH40ksvlXy+igYB59xzwJQib51YyXkNw6gdkY8YbDQuu+wyIKcJhBEstHbtWiAzU6kxLoxgnUZa5vnxF3eFyl2FtnfAMCKOaQLkjGi1yIQbNB/+8IcBOPzww4Fc4sla1SEEWL16tfc6jHDhRqe/vRCqFQxG87NBgJyq2kjx5X/+859Da1sz3ohIXRgGGw2NBq2W58eWA4YRcUwTILfTrVEiysLiqKOOAnonvVC1tFGNdLVm5cqV3utqpZAzTcAwIo5pAtBQ6cXCRDUpfwJM7VMzEFaHZDJZdXevaQKGEXFME/Bh69bqo1rBCy+8ELIkjcExxxzj2Vw0WKhSL4FpAmRcLk1NTSSTSVsSVIBuX/aj23kbaStvmHR1dXlbibu7uwccAPbbb78Bz2mDgGFEHFsOkFOn3n777ZAlaTw0sMWojCCjWU0TMIyIY5oAjRUuHCarVq0qOKY58g888MBai9MQjBgxAgi2noQNAuR82HvssQfQWFVsw6aRsvzWmvz8gYNBB41169YN+FlbDhhGxDFNgJzKtXXr1pAlaTw0TkCNrvvss0+Y4gwJWltbgYxRtRY5Ik0TMIyIY5oAufVTGFVyGp2uri4ANm7cCJgm0B/qTlXtSUQKbAJBGAhtECC3DBg1alTIkgxt/Devoq8/9rGPARaaXQx9+LVvNLIymUx63hV9+Nvb27307aXkbyzFsGjLAcOIODYIAMOGDWPYsGEW314hWgLLj+7LaG1tpbW1lRUrVvTx7ejS0tJS1JXa09PjvReLxYjFYmzbts3bO1AtbBAwjIgTeZuAPyDDAluqj9oJdOY6+uijLULTx6hRo7zy8dovmqg1lUrxy1/+EoAvf/nL3nfUhqA7XgfqT3WB90VFmoCIfEdEXhKRF0XkLhFpEZGJIrJCRNaIyKJsiTLDMOqUsjUBEdkX+F/AIc65nSJyN/AFYBrwE+fcQhG5GbgIuKkq0hp1yRlnnAGUbq2uVoLMRqCrq8ub2bU/VDNYtGgRZ599dsF3NICo1B2aA3lkKl0OJIBWEdkNtAHrgU8B52Xfvx24iiEyCFgevPLYsmULkMt0098gkE6nGT58OACLFy8GcoNIlNABc/To0d5eFX2of/Ob3wAUHQD8aGGZj3/84xXJUvZywDn3FvBj4A0yD/8HwCpgi3NOM02uA/Yt9n0RuVhEOkSkY8OGDeWKYRhGhVSyHBgFzAAmAluA/wJOKfX7zrl5wDyAKVOmhBZBUqy2uzE4HnnkEaB43YZiKv/27dsBuOCCCwAYM2YMxx13XIAS1g9qzNN7bceOHd7ru+66C4Dp06eXdK5JkyYB
mSVWfvkxf7/3V7YMKjMMfhr4l3Nug3NuN7AYOAYYKSI6uEwA3qqgDcMwAqYSm8AbwJEi0gbsBE4EOoBHgbOAhcD5wAOVCmk0Frr21X0FJ510Uk12y4VNc3Nzr/BfyPTBggULgNI1gGpT9iDgnFshIvcAzwAp4Fky6v3/AxaKyDXZY7dWQ1Cj/lG1diADa/4eg0aN0pw1axYAd9xxBwAjR470jKhKLBbjnHPOqblsfiryDjjnrgSuzDv8T+ATlZzXMIzaEfmIQX/57GqXd4oKe+65J1CaixAKjYU9PT0cfvjhQO+Cm0OZWCzm7UpVTWfLli1eNKCiMQGVtqWUE3thd71hRJzIawKQW8uai3DwXHPNNZ6Br5SIwb54/vnnqypXWKjRM5FIFLjm/Fqn5gSoBt3d3Z7rUdsfjKE18oOAlh+D3tV0jdL44Q9/6L2uxMCnA7CmLa80Cq5WXHllxiR23XXXATljZywW81R9zRmYTqer+vArsVjMG3iLJXYZ8PtVl8gwjCGF1IMKPGXKFNfR0RFK201NTd4oqgYWTelk9I3fz10NtO/1WmjyETUY1iv+zMCQ02j895BuUd+8eXMgMsTj8ZLcs+l0epVzbkr+cdMEDCPiRN4mADnXllYgMvpG7SfV3nGp10DPe+yxxwLwl7/8hcmTJ1e1rWqxxx57eIY+nfl37twJZGZ/Nc7psXrFNAHDiDimCfiYP39+2CLUNclksiqBVfkBM4lEwrMtqKahHHnkkXW7r6C7u9tb72t/6Nq8ra2tZhqA366l+P/O36uRjw0C5Nwqn/nMZ0KWJBqo6q83ZSqV8m7a/KhD55yXhGTbtm21FrVfurq6imZXBnj//fdrJkd3d3dB4RI/OlD1NQjYcsAwIo5pAuSCXJ577jkgl6zByHDvvfdW9XynnXYaAHfeead3TF2OutxQtToWi3kzmLrj6sXQlkgkPI1l9OjRAKxfv77mcviXA8VUf3VN9mXMNU3AMCKOaQLkZp2TTz4ZgPfeey9MceoGNdLpb/8OtVJ3q+l61B+U5tcAlPxwWp31U6mU15b+njRpkqe1hUkikWCvvfYCYO3ataHJ4e8j7ef8DMb9YYOAj3qInqwHHnggkwxKjXSVRFCqWqo3Y6kPrxppf//73xfcyK+88opXRHagwhpB8r3vfY+rr746tPb7Q/vdCpIahjEgtnfAZ1QZNmwYgJcHPor09PR4s2u+AU5E+tWWihme9POqVQzW59/a2up9x6+ZaFtRLmmmbsi99967IH7Dv5dBtzSLiO0dMAyjkMjbBOLxuLfm1XVmlGltbfVm2fyZPZFIeDOMBu586Utf8vLlF8t5r+fQWXyw7Ny5k7a2NiBnmxARry19b6Dc+o3IsmXL+nxPr4GmOOuPyA8Cllcwg/ZDPB7vMzFFPB4viNr77W9/y8iRIwG81Nn+wbQa/Xv00UcD8Kc//aniczUCixYtAuCiiy7q8zM6CJSS6MWeAMOIOJE3DPpVSx0163XDShCsWbMGgI985CPeMdUA9N5QX/9Asfuf/OQnAdBrmU6nvWWAaheVZNf1xy3otdL4AudcZFy8+fsEihlktd/9/W2GQcMwihJ5mwDkkolE0TX40Y9+tOBY/o6+UnfvPfHEEwC9DIs6W+nsVQlqGBw+fLgXDKNu3agkiX3uuec8jUejKovtDix2XftiQE1ARG4TkfdE5EXfsdEi8rCIvJr9PSp7XETkZyKyRkSeF5H6TAljGIZHKZrAb4D/BO7wHbsCWO6cu15Ersj+/X3gVOCg7M8RwE3Z33WNlsqOmqfg8ccfLxoSrBrAGWecUdZ5dYZqb2/3+racyjh9sW3bNq/qkdb2SyaTns2gkRPFrl+/3nOLat8WYzCVnAYcBJxzj4vIAXmHZwDHZ1/fDjxGZhCYAdzhMvrK0yIyUkTGO+dqv7+yRETEu2lK8ak2EjNmzCiINBs2bBif//znAbjtttvKOq+qp7t27QrMWLdhwwYgFyfgnPOW
BMXiFRoF3YZdTcqd+sb5Hux3gHHZ1/sCb/o+ty57rAARuVhEOkSkQy+oYRi1p2LDoHPOicigh3vn3DwypcyZMmVKaL4dEWHMmDFALhZ7zJgxNU0PVWs06CaRSHj5/nT23rp1a9kaQDH0/NUovNnf+VOpVIFrs5KyaPXGk08+Gdi5y9UE3hWR8QDZ37oB/y1gP9/nJmSPGYZRp5SrCSwBzgeuz/5+wHf8EhFZSMYg+EG92gPUiBSLxdi4cSOQC4pp1D0EixcvBuCss84CMrOoBkbprFlto5rOzmosfOaZZ6paR8CfOstfBxDwwpkbQbMrNaXaJz7xiUGfe8BBQETuImMEHCsi64AryTz8d4vIRcBa4Jzsx38PTAPWAJ3ABYOWqEboTZlMJr2HvtF9zbNmzQJ6Z/vVbcNBxUjoMkAHmVNPPZV333236u0457yHX/8nva6xWKxuMxaXyqc//WmguAfL74Uph1K8A1/s460Ti3zWAd8qWxrDMGpOZCMGdZaIx+ORiA+47777vDh7v9o8c+bMQNtVrUONdUEV5YScS3DixIlA7honk0nvf1fNJL8ASr2jS9Vi+1p27drl9fNTTz016HM3/t1vGEa/RFYTUKKy8+zMM88smP0uv/xyrr/++qq3dfzxxwO5vQSQm3krWbuWyr/+9S8gVwtg+/btnvajwUVDxW04blwmBEcN2cU0gVQq5e3gLAfTBAwj4kReE/AX2WxE9t9/fyBj+9DZRNfFQWgBkJttobC2YC3rBWzatAnIXGOVQ3+/8cYbXt/UK8uXL/fsGv3do21tbTz22GNltxPZQcBvHGtEw+Dvfvc7ANatWwdkDEv68AcVvVeM/OVWGDEYJ5xwgveQ+I2H9b4kOPnkkwuiH4tx3nnnVdRO4939hmEMishqApqMoqurq6GWA7/97W8BuOCCTJyWvwxYLTUARQ2CqgFo4dFa8tBDD3lLIX/9g8MOOwyAZ599tuYy9Yf2WTwe97QVvUdbWloK0ovNnTu3ovZMEzCMiBNZTSA/cKZR+NrXvgYUFqT84Q9/GIo8OlvpDFzNfQODQV1o6rZMJBK89NJLocjSF5rstT+3tb/4aLVKtEd2EPBvQW2k5UB+BV9l9uzZYYjj+bXDzvbzyLrqG2QAAAllSURBVCOPAL3Lc6mqrXENlVjYK2Hp0qVALvOzyrhz586itR+mT59e1fZtOWAYESdymsA3v/lNIDdD9eUeDDoZRhC0t7cXFP4MQ37VQkSk1+t6QPtjxIgRXiKVFStWhCkS55yT2YSrfaVy+WM7/MVYFy5cWNX2TRMwjIgTuUFgwYIFLFiwwAsS6iuHQDqdJp1OM2fOHObMmVNjKQdHPB7vVZFnx44d7Nq1KzQtpq2tjba2troMxtHkIyeccIJ3rKenh56eHpYsWVJzeR599NFeCVHyUdn0GlfLGOgnEmXIvv71rwNw++23FywD+hoI/OoX1HfmWv1fRKSuHrxkMund3NrH9dSP+QVR7r//fqZNm1ZTGQ499FBeffVVoHhSG7221TCsWhkywzCKMuQNg/PnzwdybpYlS5Z4s6FqOTob+WfJvlxpiqrX9aApDcTDDz8MwIknFiR7ChV/39aLYdDP22+/DeQ2WZ199tle/EitePnll/t8r6WlhV/96leBy2CagGFEnLrXBHQ3XDqd5uyzzwZy8eednZ0FiTLUkAKFUXOJRKLkZKL6OT3Xz3/+cwAuvfTSsv+XoKg3DWCoMHbsWCDXfxpQVC+cddZZFe8QLAXTBAwj4tSVd2DFihVMnTq1onNVy/qcH0Tkr3P30EMPATYDD0Q8Hi/QxoZS8FUt8HtQFPVMdXZ2VrWtvrwDdbUcmDp1al25kPzohWlpafEKdhYzIumg2tLS4kV+6ZIlCB9vvaPLqf4MYFFEB8dYLOa9VvfkokWLaiqLLQcMI+LUhSbQ1dXFK6+8Qmtrq5eNtr90Sv0RlCahy4Pu7u5+3YYa651KpTztQZcS9ZrEIkj0
Our/3qgl3kpFtUK9h3p6ejjjjDMAuOuuu0KRaUBNQERuE5H3RORF37H/KyL/EJHnReQ+ERnpe2+2iKwRkdUi8pmgBDcMozqUogn8BvhP4A7fsYeB2c65lIj8CJgNfF9EDgG+APwPYB/gjyLy35xz/caybtq0iQULFrBz505vpNTwXs1c29nZOejAnWIBKsV2DeoxPX8ymew1UkNuNh9IBv9aT20Cl1xyCQA33HDDoOQf6jjnPM3o1FNPDVma8Glvb/fuI9WQ4vF4aBqAUpJ3QEQOAJY65w4t8t7ngbOcczNFZDaAc+667HvLgKucc/3WRiq2d8Af7w+ZQaGUrMBf+cpXCo5pXMGNN9444PchU7hTb95bbrkFyA0oA/XXhRde6L2uNPfbUEUf+D/84Q/ekuhvf/sbAIccckhocoWF3ktNTU299qwAvPnmm14B1aAJcu/AhcCD2df7Am/63luXPVZMoItFpENEOjZs2FAFMQzDKIeKDIMi8gMgBSwY7Hedc/OAeZDRBPLfv/nmm3v9riU6+wPcdNNNNW+/UYjFYt4sGEaW4bDxF2HRv1Wj1F2BtdIC+qPsQUBEvgKcBpzocjryW8B+vo9NyB4zDKNOKWsQEJFTgO8Bxznn/GFNS4A7ReQGMobBg4C/ViylMaRQo1csFvNmvFL3bDQK/vJ2uv7fsmWLpwkUKywaFgMOAiJyF3A8MFZE1gFXkvEGNAMPZ/+pp51zX3fOvSQidwN/J7NM+NZAngHDaCSGDx8OZAzJqur76wnWQ5h+PgMOAs65LxY5fGs/n/8P4D8qEcowjNpRFxGDRuOiLkItrFFPanA10DgSnfX9W9A/+OADIKcd7LPPPiFIODC2d8AwIo5pAkag6Myfn9SzEXj//fc58MADgdwOUS1060e1oKeffrp2wg0C0wQMI+KYJmAEiu4FqadU6NViwoQJBUlS1D060G7TesIGASNQhkLdhsGiS5tUKuU96HvttReQy1y8cuXKcIQrA1sOGEbEqYscgyKyAdgBbAxbFmAsJocfk6M3Q1mOf3PO7Zl/sC4GAQAR6Si2zdHkMDlMjmDlsOWAYUQcGwQMI+LU0yAwL2wBspgcvTE5etNwctSNTcAwjHCoJ03AMIwQsEHAMCJOXQwCInJKtk7BGhG5okZt7icij4rI30XkJRH5dvb4aBF5WERezf4eVSN54iLyrIgszf49UURWZPtkkYgEvgNHREaKyD3ZmhIvi8hRYfSHiHwne01eFJG7RKSlVv3RR52Non0gGX6Wlel5EZkcsBzB1PtwzoX6A8SB14APAU3A34BDatDueGBy9vVw4BXgEOD/AFdkj18B/KhG/fBd4E4yqd0B7ga+kH19M/CNGshwOzAr+7oJGFnr/iCTnfpfQKuvH75Sq/4AjgUmAy/6jhXtA2AamUzbAhwJrAhYjpOBRPb1j3xyHJJ9bpqBidnnKV5yW0HfWCX8s0cBy3x/zyZT2KTWcjwAnASsBsZnj40HVteg7QnAcuBTwNLsTbXRd8F79VFAMuyRffgk73hN+4Nc2vrRZPa2LAU+U8v+AA7Ie/iK9gEwF/hisc8FIUfee58HFmRf93pmgGXAUaW2Uw/LgZJrFQRFtrjKYcAKYJxzbn32rXeAcTUQ4UYyiVt1l80YYItzTrNz1qJPJgIbgF9nlyW3iEg7Ne4P59xbwI+BN4D1wAfAKmrfH3766oMw792y6n0Uox4GgVARkWHAvcBlzrle1TJdZlgN1IcqIqcB7znnVgXZTgkkyKifNznnDiOzl6OXfaZG/TEKmEFmUNoHaAdOCbLNwVCLPhiISup9FKMeBoHQahWISJLMALDAObc4e/hdERmffX888F7AYhwDTBeR14GFZJYEPwVGiohu9a5Fn6wD1jnnVmT/vofMoFDr/vg08C/n3Abn3G5gMZk+qnV/+OmrD2p+7/rqfczMDkgVy1EPg8BK4KCs9beJTEHTJUE3Kplc6bcCLzvn/JVClwDnZ1+f
T8ZWEBjOudnOuQnOuQPI/O+POOdmAo8CZ9VQjneAN0Xk4OyhE8mkjq9pf5BZBhwpIm3Za6Ry1LQ/8uirD5YA/zPrJTgS+MC3bKg6vnof011hvY8viEiziExksPU+gjTyDMIAMo2Mdf414Ac1anMqGbXueeC57M80Muvx5cCrwB+B0TXsh+PJeQc+lL2Qa4D/Appr0P4koCPbJ/cDo8LoD+Bq4B/Ai8B8MlbvmvQHcBcZW8RuMtrRRX31ARkD7i+y9+0LwJSA5VhDZu2v9+vNvs//ICvHauDUwbRlYcOGEXHqYTlgGEaI2CBgGBHHBgHDiDg2CBhGxLFBwDAijg0ChhFxbBAwjIjz/wHRpO9Mqjw+CwAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "import numpy as np\n",
    "from matplotlib import pyplot as plt\n",
    "\n",
    "plt.imshow(imgs.asnumpy()[0, 0, :, :], cmap=\"gray\")\n",
    "classes[labels.asnumpy()[0]]"
   ]
  },
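  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Model training code - building the model\n",
    "\n",
    "mxnet builds models as dynamic graphs, which keeps model construction simple. The example below builds the model as a class; simple models can also be built directly with Sequential.\n",
    "\n",
    "The more complex model here is itself assembled from simple Sequential models.\n",
    "\n",
    "> The model built here is VGG; for more details on VGG, see arXiv:1409.1556."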
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "class MyModel(mx.gluon.nn.HybridBlock):\n",
    "    def __init__(self):\n",
    "        super(MyModel, self).__init__()\n",
    "        # The model has two main parts: the feature extractor and the classifier\n",
    "\n",
    "        # Feature extractor\n",
    "        self.feature = mx.gluon.nn.HybridSequential()\n",
    "\n",
    "        self.feature.add(self.conv(64))\n",
    "        self.feature.add(self.conv(64, add_pooling=True))\n",
    "\n",
    "        self.feature.add(self.conv(128))\n",
    "        self.feature.add(self.conv(128, add_pooling=True))\n",
    "\n",
    "        self.feature.add(self.conv(256))\n",
    "        self.feature.add(self.conv(256))\n",
    "        self.feature.add(self.conv(256, add_pooling=True))\n",
    "        self.feature.add(self.conv(512))\n",
    "        self.feature.add(self.conv(512))\n",
    "        self.feature.add(self.conv(512, add_pooling=True))\n",
    "\n",
    "        self.feature.add(self.conv(512))\n",
    "        self.feature.add(self.conv(512))\n",
    "        self.feature.add(self.conv(512, add_pooling=True))\n",
    "        self.feature.add(mx.gluon.nn.GlobalAvgPool2D())\n",
    "        self.feature.add(mx.gluon.nn.Flatten())\n",
    "\n",
    "        self.feature.add(mx.gluon.nn.Dense(4096, activation=\"relu\"))\n",
    "        self.feature.add(mx.gluon.nn.BatchNorm())\n",
    "\n",
    "        self.feature.add(mx.gluon.nn.Dense(4096, activation=\"relu\"))\n",
    "        self.feature.add(mx.gluon.nn.BatchNorm())\n",
    "\n",
    "        self.feature.add(mx.gluon.nn.Dropout(0.5))\n",
    "        # This simple structure is the classifier\n",
    "        self.pred = mx.gluon.nn.Dense(100)\n",
    "\n",
    "    def conv(self, filters, add_pooling=False):\n",
    "        # The model is built largely from repeated modules;\n",
    "        # factoring them out here simplifies model construction\n",
    "        model = mx.gluon.nn.HybridSequential()\n",
    "        model.add(mx.gluon.nn.Conv2D(filters, 3, padding=1, activation=\"relu\"))\n",
    "        model.add(mx.gluon.nn.BatchNorm())\n",
    "\n",
    "        if add_pooling:\n",
    "            model.add(mx.gluon.nn.MaxPool2D(strides=2))\n",
    "        return model\n",
    "\n",
    "    def hybrid_forward(self, F, x):\n",
    "        # hybrid_forward defines how the parts of the model are chained together\n",
    "\n",
    "        x = self.feature(x)\n",
    "        return self.pred(x)"
   ]
  },
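  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick sanity check on the architecture above, the spatial size of the feature map before global average pooling can be worked out by hand. The sketch below assumes the 128x128 input produced by `Resize` earlier and counts the five blocks built with `add_pooling=True`.\n",
    "\n",
    "```python\n",
    "size = 128  # input height and width after Resize\n",
    "pools = 5   # number of conv blocks built with add_pooling=True\n",
    "for _ in range(pools):\n",
    "    # 3x3 convolutions with padding=1 keep the size; each 2x2 pool with stride 2 halves it.\n",
    "    size //= 2\n",
    "print(size)  # prints 4\n",
    "```\n",
    "\n",
    "So `GlobalAvgPool2D` averages a 4x4 map for each of the 512 channels, and the first `Dense` layer receives 512 features."
   ]
  },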
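  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The code above uses the gluon API. Besides gluon, mxnet also offers a purely symbolic, static-graph style, but it is less convenient than gluon. Like keras for Tensorflow, the gluon API is the approach the mxnet community promotes.\n",
    "\n",
    "> Modules with the Hybrid prefix, such as HybridBlock, can be compiled into a static graph under the hood, which is faster, so prefer them where possible. Modules without the Hybrid prefix are used somewhat differently.\n",
    "\n",
    "Instantiate a model and take a look:"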
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "MyModel(\n",
       "  (feature): HybridSequential(\n",
       "    (0): HybridSequential(\n",
       "      (0): Conv2D(None -> 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), Activation(relu))\n",
       "      (1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=None)\n",
       "    )\n",
       "    (1): HybridSequential(\n",
       "      (0): Conv2D(None -> 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), Activation(relu))\n",
       "      (1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=None)\n",
       "      (2): MaxPool2D(size=(2, 2), stride=(2, 2), padding=(0, 0), ceil_mode=False, global_pool=False, pool_type=max, layout=NCHW)\n",
       "    )\n",
       "    (2): HybridSequential(\n",
       "      (0): Conv2D(None -> 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), Activation(relu))\n",
       "      (1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=None)\n",
       "    )\n",
       "    (3): HybridSequential(\n",
       "      (0): Conv2D(None -> 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), Activation(relu))\n",
       "      (1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=None)\n",
       "      (2): MaxPool2D(size=(2, 2), stride=(2, 2), padding=(0, 0), ceil_mode=False, global_pool=False, pool_type=max, layout=NCHW)\n",
       "    )\n",
       "    (4): HybridSequential(\n",
       "      (0): Conv2D(None -> 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), Activation(relu))\n",
       "      (1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=None)\n",
       "    )\n",
       "    (5): HybridSequential(\n",
       "      (0): Conv2D(None -> 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), Activation(relu))\n",
       "      (1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=None)\n",
       "    )\n",
       "    (6): HybridSequential(\n",
       "      (0): Conv2D(None -> 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), Activation(relu))\n",
       "      (1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=None)\n",
       "      (2): MaxPool2D(size=(2, 2), stride=(2, 2), padding=(0, 0), ceil_mode=False, global_pool=False, pool_type=max, layout=NCHW)\n",
       "    )\n",
       "    (7): HybridSequential(\n",
       "      (0): Conv2D(None -> 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), Activation(relu))\n",
       "      (1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=None)\n",
       "    )\n",
       "    (8): HybridSequential(\n",
       "      (0): Conv2D(None -> 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), Activation(relu))\n",
       "      (1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=None)\n",
       "    )\n",
       "    (9): HybridSequential(\n",
       "      (0): Conv2D(None -> 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), Activation(relu))\n",
       "      (1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=None)\n",
       "      (2): MaxPool2D(size=(2, 2), stride=(2, 2), padding=(0, 0), ceil_mode=False, global_pool=False, pool_type=max, layout=NCHW)\n",
       "    )\n",
       "    (10): HybridSequential(\n",
       "      (0): Conv2D(None -> 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), Activation(relu))\n",
       "      (1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=None)\n",
       "    )\n",
       "    (11): HybridSequential(\n",
       "      (0): Conv2D(None -> 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), Activation(relu))\n",
       "      (1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=None)\n",
       "    )\n",
       "    (12): HybridSequential(\n",
       "      (0): Conv2D(None -> 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), Activation(relu))\n",
       "      (1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=None)\n",
       "      (2): MaxPool2D(size=(2, 2), stride=(2, 2), padding=(0, 0), ceil_mode=False, global_pool=False, pool_type=max, layout=NCHW)\n",
       "    )\n",
       "    (13): GlobalAvgPool2D(size=(1, 1), stride=(1, 1), padding=(0, 0), ceil_mode=True, global_pool=True, pool_type=avg, layout=NCHW)\n",
       "    (14): Flatten\n",
       "    (15): Dense(None -> 4096, Activation(relu))\n",
       "    (16): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=None)\n",
       "    (17): Dense(None -> 4096, Activation(relu))\n",
       "    (18): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=None)\n",
       "    (19): Dropout(p = 0.5, axes=())\n",
       "  )\n",
       "  (pred): Dense(None -> 100, linear)\n",
       ")"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "ctx = mx.gpu(0)\n",
    "model = MyModel()\n",
    "model.initialize(ctx=ctx, init=mx.initializer.Xavier())\n",
    "model.hybridize()\n",
    "\n",
    "model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Model training code - training setup\n",
    "To train the model, we also need to define the loss, the optimizer, and so on."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "loss_object = mx.gluon.loss.SoftmaxCrossEntropyLoss()\n",
    "optimizer = mx.gluon.Trainer(model.collect_params(), mx.optimizer.Adam())  # the optimizer has parameters that can be tuned\n",
    "\n",
    "train_accuracy = mx.metric.Accuracy()\n",
    "val_accuracy = mx.metric.Accuracy()"
   ]
  },
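  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the loss values easier to interpret, here is a NumPy sketch of softmax cross-entropy on raw scores with integer labels. It illustrates the formula only and is not mxnet's implementation; the function name `softmax_cross_entropy` is a hypothetical helper.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def softmax_cross_entropy(logits, labels):\n",
    "    # Subtract the row maximum before exponentiating for numerical stability.\n",
    "    shifted = logits - logits.max(axis=1, keepdims=True)\n",
    "    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))\n",
    "    # Pick the log-probability of the true class for every sample.\n",
    "    return -log_probs[np.arange(len(labels)), labels]\n",
    "\n",
    "\n",
    "logits = np.zeros((4, 100))  # identical scores for all 100 classes\n",
    "labels = np.array([0, 1, 2, 3])\n",
    "print(softmax_cross_entropy(logits, labels).mean())  # about 4.605, i.e. ln(100)\n",
    "```\n",
    "\n",
    "A model that guesses uniformly over 100 classes has a loss of about 4.6 and an accuracy of about 1%, which is a useful baseline when reading a training log."
   ]
  },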
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time  # used to track the training speed by hand during training"
   ]
  },
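  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Because training is usually a repeated, iterative process, it is worth saving intermediate state regularly so that training can be restarted.\n",
    "Here each ckpt is saved in two copies, which makes it easy to resume an interrupted run."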
   ]
  },
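  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One possible way to keep two copies of a checkpoint is sketched below. Everything in it is an assumption made for illustration: the helper name `save_two_copies`, the `save_fn` callable, and the per-epoch file name pattern are hypothetical and may differ from what the training loop of this tutorial actually does.\n",
    "\n",
    "```python\n",
    "import shutil\n",
    "\n",
    "\n",
    "def save_two_copies(save_fn, epoch, latest='model.params',\n",
    "                    pattern='model-{:04d}.params'):\n",
    "    # save_fn is any callable that writes a checkpoint to the given path,\n",
    "    # for example model.save_parameters.\n",
    "    save_fn(latest)\n",
    "    # Keep a per-epoch copy so that an earlier state can be restored later.\n",
    "    per_epoch = pattern.format(epoch)\n",
    "    shutil.copyfile(latest, per_epoch)\n",
    "    return per_epoch\n",
    "```\n",
    "\n",
    "The fixed name makes resuming simple, while the per-epoch copies allow going back to the checkpoint with the best validation accuracy."
   ]
  },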
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "if os.path.exists(\"model.params\"):\n",
    "    # Check whether a checkpoint exists\n",
    "    # and load it if it does\n",
    "    model.load_parameters(\"model.params\")\n",
    "\n",
    "    # This is a rather blunt approach; one could also review the earlier training runs\n",
    "    # and manually pick the checkpoint with the highest accuracy to load.\n",
    "    print(\"model loaded\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 0 Loss 6.9573759181431365, Acc 1.042857142857143, Val Loss 15.420175067138672, Val Acc 1.68\n",
      "Speed train 334.933801136163imgs/s val 791.7329888468499imgs/s\n",
      "Epoch 1 Loss 6.7644785858154295, Acc 1.1114285714285714, Val Loss 21.39283991394043, Val Acc 0.9400000000000001\n",
      "Speed train 344.24428987859835imgs/s val 1005.8295516339732imgs/s\n",
      "Epoch 2 Loss 6.427109525844029, Acc 1.1142857142857143, Val Loss 1441.4082737342835, Val Acc 1.46\n",
      "Speed train 350.2307806800212imgs/s val 1025.2414031403846imgs/s\n",
      "Epoch 3 Loss 5.979110830688477, Acc 1.1114285714285714, Val Loss 402112.5587174507, Val Acc 1.08\n",
      "Speed train 343.90973606373484imgs/s val 1007.538116122915imgs/s\n",
      "Epoch 4 Loss 5.7047360944475445, Acc 1.0114285714285716, Val Loss 28.884018599700926, Val Acc 1.16\n",
      "Speed train 346.9289205956802imgs/s val 1042.5033657692259imgs/s\n",
      "Epoch 5 Loss 5.429534565952846, Acc 1.1828571428571428, Val Loss 25.10572717514038, Val Acc 1.06\n",
      "Speed train 346.55915907021785imgs/s val 1042.9843522905003imgs/s\n",
      "Epoch 6 Loss 5.1636578002929685, Acc 1.18, Val Loss 292.2765090805054, Val Acc 1.26\n",
      "Speed train 346.677920645666imgs/s val 1041.1619895883405imgs/s\n",
      "Epoch 7 Loss 4.944499089268276, Acc 1.4000000000000001, Val Loss 8.26984390258789, Val Acc 1.66\n",
      "Speed train 345.7027344018046imgs/s val 1041.9690306718644imgs/s\n",
      "Epoch 8 Loss 4.776889256722587, Acc 1.4857142857142858, Val Loss 5.135510383605957, Val Acc 1.9\n",
      "Speed train 347.51812408602746imgs/s val 1055.2600700466792imgs/s\n",
      "Epoch 9 Loss 4.627297317940848, Acc 2.2885714285714287, Val Loss 19.935467903900147, Val Acc 3.36\n",
      "Speed train 345.4781099199051imgs/s val 1012.3481105170373imgs/s\n",
      "Epoch 10 Loss 4.379766181509836, Acc 3.842857142857143, Val Loss 10.535897972106934, Val Acc 6.4399999999999995\n",
      "Speed train 346.03842686704553imgs/s val 1046.8390845908468imgs/s\n",
      "Epoch 11 Loss 3.685629348100935, Acc 11.297142857142857, Val Loss 52.809229167175296, Val Acc 19.12\n",
      "Speed train 348.1024735895058imgs/s val 1005.0561635670016imgs/s\n",
      "Epoch 12 Loss 2.9683764310564316, Acc 24.19142857142857, Val Loss 96.89430379104614, Val Acc 34.64\n",
      "Speed train 347.82284815108676imgs/s val 1015.7570915242692imgs/s\n",
      "Epoch 13 Loss 2.32016633845738, Acc 39.92285714285714, Val Loss 1393.9259933712005, Val Acc 48.64\n",
      "Speed train 347.1405956344324imgs/s val 1002.6255219358517imgs/s\n",
      "Epoch 14 Loss 1.7546388859340123, Acc 53.98285714285714, Val Loss 405771.5952173068, Val Acc 62.68\n",
      "Speed train 347.7790321353627imgs/s val 1016.1043531852587imgs/s\n",
      "Epoch 15 Loss 1.4703079643249513, Acc 61.34285714285714, Val Loss 55974.54370078631, Val Acc 67.4\n",
      "Speed train 341.66829982471imgs/s val 996.3958929419091imgs/s\n",
      "Epoch 16 Loss 1.0772937241145544, Acc 71.44857142857143, Val Loss 175934832.46710944, Val Acc 74.83999999999999\n",
      "Speed train 333.6240690789895imgs/s val 988.0520794326872imgs/s\n",
      "Epoch 17 Loss 0.9253648305075509, Acc 75.28285714285714, Val Loss 24513272.186748803, Val Acc 73.38\n",
      "Speed train 324.7226062593469imgs/s val 964.7906383539004imgs/s\n",
      "Epoch 18 Loss 0.817839416544778, Acc 77.78, Val Loss 21546954.254889924, Val Acc 78.3\n",
      "Speed train 320.69584017889264imgs/s val 945.9076133097823imgs/s\n",
      "Epoch 19 Loss 0.869460376398904, Acc 76.63142857142857, Val Loss 8773747185.209038, Val Acc 78.72\n",
      "Speed train 314.57207850646853imgs/s val 936.1137303334394imgs/s\n"
     ]
    }
   ],
   "source": [
    "EPOCHS = 20\n",
    "for epoch in range(EPOCHS):\n",
    "\n",
    "    train_loss = 0\n",
    "    train_samples = 0\n",
    "    train_accuracy.reset()\n",
    "    val_accuracy.reset()\n",
    "\n",
    "    val_loss = 0\n",
    "    val_samples = 0\n",
    "\n",
    "    start = time.time()\n",
    "    for imgs, labels in img_train:\n",
    "        imgs = imgs.as_in_context(ctx)\n",
    "        labels = labels.as_in_context(ctx)\n",
    "\n",
    "        with mx.autograd.record():\n",
    "            preds = model(imgs)\n",
    "            loss = loss_object(preds, labels)\n",
    "        loss.backward()\n",
    "\n",
    "        optimizer.step(batch_size)\n",
    "\n",
    "        train_loss += loss.sum().asscalar()\n",
    "        train_accuracy.update(labels, preds)\n",
    "\n",
    "        train_samples += imgs.shape[0]\n",
    "        mx.nd.waitall()\n",
    "\n",
    "    train_samples_per_second = train_samples / (time.time() - start)\n",
    "\n",
    "    start = time.time()\n",
    "    for imgs, labels in img_val:\n",
    "        imgs = imgs.as_in_context(ctx)\n",
    "        labels = labels.as_in_context(ctx)\n",
    "\n",
    "        preds = model(imgs)\n",
    "        loss = loss_object(preds, labels)\n",
    "\n",
    "        val_loss += loss.sum().asscalar()\n",
    "        val_accuracy.update(labels, preds)\n",
    "\n",
    "        val_samples += imgs.shape[0]\n",
    "        mx.nd.waitall()\n",
    "\n",
    "    val_samples_per_second = val_samples / (time.time() - start)\n",
    "\n",
    "    print(\n",
    "        \"Epoch {} Loss {}, Acc {}, Val Loss {}, Val Acc {}\".format(\n",
    "            epoch,\n",
    "            train_loss / train_samples,\n",
    "            train_accuracy.get()[1] * 100,\n",
    "            val_loss / val_samples,\n",
    "            val_accuracy.get()[1] * 100,\n",
    "        )\n",
    "    )\n",
    "    print(\n",
    "        \"Speed train {}imgs/s val {}imgs/s\".format(\n",
    "            train_samples_per_second, val_samples_per_second\n",
    "        )\n",
    "    )\n",
    "\n",
    "    model.save_parameters(\"model.params\")\n",
    "    model.save_parameters(\"model-{:04d}.params\".format(epoch))\n",
    "\n",
    "    # 每个 epoch 保存一下模型，需要注意每次\n",
    "    # 保存要用一个不同的名字，不然会导致覆盖，\n",
    "    # 同时还要关注一下磁盘空间占用，防止太多\n",
    "    # chekcpoint 占满磁盘空间导致错误。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 一些技巧\n",
    "\n",
    "因为这里定义的模型比较大，同时训练的数据也比较多，每个 epoch 用时较长，因此，如果代码有 bug 的话，经过一次 epoch 再去 debug 效率比较低。\n",
    "\n",
    "这种情况下，我们使用的数据生成过程又是自己手动指定数据数量的，因此可以尝试缩减模型规模，定义小一些的数据集来快速验证代码。在这个例子里，我们可以通过注释模型中的卷积和全连接层的代码来缩减模型尺寸，通过修改训练循环里面的数据数量来缩减数据数量。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 训练的速度很慢\n",
    "一开始训练速度比较慢，因为 mxnet 默认初始化方式是 uniform。\n",
    "\n",
    "这里改成了跟 tf 一样的 xavier （tf里面叫做 glorot_uniform）之后训练速度还是比较慢，要 10 个 epoch才能看到收敛。而且 TF 里面 20epochs 能达到 90% 的准确率，这里菜 76%，应该是哪里有什么问题，我再看看怎么解决。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
